/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file matmul_macro_v200_impl.h
 * \brief
 */
#ifndef IMPL_MATMUL_MATMUL_MACRO_V200_IMPL_H
#define IMPL_MATMUL_MATMUL_MACRO_V200_IMPL_H

#include "kernel_operator.h"
#include "matmul_macro_utils.h"

namespace matmul {
using namespace AscendC;

// ===========mad template=================/
// Cmatrix type, Amatrix type, Bmatrix type, L0C_using_uniflag, L0C_using_hset
// ===========mad template=================/
// Macro-level matmul for V200: moves A/B fractals from L1 into L0A/L0B and
// issues MAD instructions, double-buffering L0 via ping/pong events.
// Cmatrix type, Amatrix type, Bmatrix type, L0C_using_uniflag, L0C_using_hset
template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN = 0, uint16_t L0AB_USING_HSET = 0>
class MacroMatmulV200 {
public:
    inline __aicore__ MacroMatmulV200(){};
    inline __aicore__ ~MacroMatmulV200();
    // addr: ping/pong base offsets inside L0A/L0B (rebased in CPU debug builds)
    uint64_t l0APing_ = L0A_PING_D;
    uint64_t l0APong_ = L0A_PONG_D;
    uint64_t l0BPing_ = L0B_PING_D;
    uint64_t l0BPong_ = L0B_PONG_D;
#ifdef ASCENDC_CPU_DEBUG
    // Host heap buffers emulating L0A/L0B in CPU debug builds (see Init()).
    uint64_t pA_;
    uint64_t pB_;
#endif
    // args
    uint64_t useL0PingPong_; // increment applied to the ping/pong flags each K tile (0 disables double buffering)
    uint16_t sAL1M_;         // A tile in L1: M extent
    uint16_t sAL1K_;         // A tile in L1: K extent
    uint16_t sAL1MOffset_;   // start offset of the current MAD block inside the A L1 tile (M direction)
    uint16_t sAL1KOffset_;   // start offset of the current MAD block inside the A L1 tile (K direction)
    uint16_t sBL1N_;         // B tile in L1: N extent
    uint16_t sBL1K_;         // B tile in L1: K extent
    uint16_t sBL1NOffset_;   // start offset inside the B L1 tile (N direction)
    uint16_t sBL1KOffset_;   // start offset inside the B L1 tile (K direction)
    uint16_t sMadM_;         // MAD extent: M
    uint16_t sMadN_;         // MAD extent: N
    uint16_t sMadK_;         // MAD extent: K (total, split into sMad0K_ steps)
    uint16_t sMad0K_;        // K processed per inner iteration of Compute()
    uint16_t sL0cInit_; // 0; normal  1:init
    uint16_t sL0cLast_; // 0; normal  1:last
    // feature map (fmatrix description consumed by load3d)
    uint16_t sFmH_;
    uint16_t sFmC_;     // NOTE(review): not assigned in Init() nor used in this file — confirm callers set it
    uint16_t sFmPadL_;
    uint16_t sFmPadR_;
    uint16_t sFmPadT_;
    uint16_t sFmPadD_;
    // state: low bit selects the ping (0) or pong (1) half of L0A/L0B
    uint16_t ssAl0PingPongFlag_;
    uint16_t ssBl0PingPongFlag_;
    // instance args
    // 0:format(M, K)
    // 1:format(K, M), transpose must be set
    uint16_t ssAmatrixTranspose_;
    // 0:format(K, N), moved with load3dv2
    // 1:format(N, K), moved with load2d
    uint16_t ssBmatrixTranspose_;
    // 0: bias
    // 1: no bias
    uint16_t biasType_;
    uint16_t hwK0_;     // elements along K in one fractal; dtype-dependent, set in Init()
    uint16_t typeSize_; // sizeof(A_T)
    uint16_t isGemv_;   // nonzero: A is a vector (GEMV path, MAD uses m = 1)
    uint16_t isScalar_; // nonzero (with isGemv_): A is a scalar broadcast from aScalar_
    A_T aScalar_;
    A_T bScalar_;
    event_t eventIdMToMte1Ping_;
    event_t eventIdMToMte1Pong_;
    // tpipe
    TPipe* tpipe_;
    TBuf<TPosition::A2> l0aBuf_;
    TBuf<TPosition::B2> l0bBuf_;

    inline __aicore__ void Init();
    inline __aicore__ void Compute(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        LocalTensor<C_T> &cMatrix);

private:
    // Load one MAD block of A from L1 into L0A (load3dv2; load2d or scalar broadcast on the GEMV path).
    inline __aicore__ void LoadL12L0A(uint64_t aPoskPtr, uint16_t usedK,
        const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A);
    // Load one MAD block of B from L1 into L0B (load2d when B is (N, K), load3dv2 otherwise).
    inline __aicore__ void LoadL12L0B(uint64_t kInner, uint16_t kC0, uint16_t kC0Tail,
        const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B);
    // Issue one MAD of shape (m, mmadK, n) into the L0C accumulator.
    inline __aicore__ void MmadMacro(const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B,
        LocalTensor<C_T> &cMatrix, uint16_t mmadK, bool isBias);
};

template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ MacroMatmulV200<C_T, A_T, B_T, UNIFLAG_EN, L0AB_USING_HSET>::~MacroMatmulV200()
{
#ifdef ASCENDC_CPU_DEBUG
    // Release the host buffers that emulate L0A/L0B in CPU debug builds
    // (allocated in Init()).
    // NOTE(review): these are freed before the M_MTE1 waits below — confirm
    // the simulated pipeline no longer touches them at this point.
    free((__ca__ A_T *)pA_);
    free((__cb__ B_T *)pB_);
#endif
    // Consume the flags left set by Init()/the last Compute() iteration so
    // the event ids can be returned to the pool in a clean state.
    WaitFlag<HardEvent::M_MTE1>(eventIdMToMte1Ping_);
    WaitFlag<HardEvent::M_MTE1>(eventIdMToMte1Pong_);
    GetTPipePtr()->ReleaseEventID<HardEvent::M_MTE1>(eventIdMToMte1Ping_);
    GetTPipePtr()->ReleaseEventID<HardEvent::M_MTE1>(eventIdMToMte1Pong_);
}

/**
 * Issue one MAD instruction of shape (m, mmadK, n) accumulating into cMatrix.
 * isBias selects cmatrixInitVal: true starts a fresh accumulation in L0C.
 */
template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulV200<C_T, A_T, B_T, UNIFLAG_EN, L0AB_USING_HSET>::MmadMacro(
    const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B, LocalTensor<C_T> &cMatrix,
    uint16_t mmadK, bool isBias)
{
    // GEMV collapses M to a single row; otherwise an M of 1 is widened to a
    // full fractal of 16 rows for the cube unit.
    const uint16_t effectiveM = isGemv_ ? 1 : ((sMadM_ == 1) ? 16 : sMadM_);

    MmadParams params;
    params.m = effectiveM;
    params.k = mmadK;
    params.n = sMadN_;
    params.cmatrixInitVal = isBias; // true: do not accumulate onto existing L0C content
    Mmad(cMatrix, l0A, l0B, params);

    // NOTE(review): small MADs are serialized on the M pipe; the threshold of
    // 10 (output size in units of ALIGN_NUM x ALIGN_NUM) presumably matches
    // instruction latency — confirm against hardware tuning notes.
    const bool isSmallMad = (effectiveM / ALIGN_NUM) * (sMadN_ / ALIGN_NUM) < 10;
    if (isSmallMad) {
        PipeBarrier<PIPE_M>();
    }
}

template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulV200<C_T, A_T, B_T, UNIFLAG_EN, L0AB_USING_HSET>::LoadL12L0A(uint64_t aPoskPtr,
    uint16_t usedK, const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A)
{
    // Describe the fmatrix for load3d: its width is whichever dimension runs
    // along the L1 row — K when A is stored (K, M), M when stored (M, K).
    if (ssAmatrixTranspose_ > 0) {
        uint16_t wAlign = CeilAlign(sAL1K_, HW_M0);
        Load3DSetFMatrixCal(sFmH_, wAlign, padList);
    } else {
        // fmatrix w should be 16 aligned
        uint16_t wAlign = CeilAlign(sAL1M_, HW_M0);
        Load3DSetFMatrixCal(sFmH_, wAlign, padList);
    }
    if (isGemv_) {
        if (isScalar_) {
            // A degenerates to a scalar: broadcast aScalar_ into L0A instead
            // of copying from L1. The high 16 bits of the first field carry
            // the fractal repeat count for the K extent.
            ASSERT(sMadM_ == 1);
            InitConstValueParams initConstValueParams {(uint16_t)((ConstCeil(sMadK_, BLOCK_CUBE * hwK0_) << 16) | 1),
                                                       0, 0, aScalar_};
            InitConstValue(l0A, initConstValueParams);
            return;
        }
        // GEMV: A is a single row of usedK elements; move it with load2d,
        // one contiguous fractal per repeat.
        int32_t fracSize = BYTE_PER_FRACTAL / sizeof(A_T);
        int32_t repeat = CeilDiv(usedK, fracSize);
        // aPoskPtr is unit of element
        LoadData2dParams loadDataParams;
        loadDataParams.repeatTimes = repeat;
        loadDataParams.srcStride = 1;
        loadDataParams.dstGap = 0;
        loadDataParams.ifTranspose = 0;
        LoadData(l0A[0], l1A[aPoskPtr], loadDataParams);
        return;
    }
    if (ssAmatrixTranspose_ > 0) {
        // format(K, M), K, M need to be 16 aligned for f32
        uint16_t madMAlign = CeilAlign(sMadM_, ALIGN_NUM);
        uint16_t usedKAlign = CeilAlign(usedK, HW_M0);
        uint16_t sAL1MAlign = CeilAlign(sAL1M_, ALIGN_NUM);
        // K_axis is m direction, and M_axis is k direction in load3d intrin
        // extConfig packs four 16-bit fields: the two start positions
        // (aPoskPtr, sAL1MOffset_) in the high half, the two extents
        // (usedKAlign, madMAlign) in the low half.
        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = sAL1MAlign;
        loadData3DV2.extConfig = ((uint64_t)aPoskPtr << 48) | ((uint64_t)sAL1MOffset_ << 32) |
                                 ((uint64_t)usedKAlign << 16) | (uint64_t)madMAlign;
        loadData3DV2.enTranspose = true;
        LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
    } else {
        // format(M, K), K_axis is k direction, and M_axis is m direction in load3d intrin
        uint16_t madMAlign = CeilAlign(sMadM_, HW_M0);
        // k direction need to be 8 aligned for f32
        uint16_t usedKAlign = CeilAlign(usedK, hwK0_);
        uint16_t sAL1KAlign = CeilAlign(sAL1K_, hwK0_);
        // Same packing as above, with the M/K roles of the fields swapped.
        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = sAL1KAlign;
        loadData3DV2.extConfig = ((uint64_t)sAL1MOffset_ << 48) | ((uint64_t)aPoskPtr << 32) |
                                 ((uint64_t)madMAlign << 16) | (uint64_t)usedKAlign;
        LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
    }
}

template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulV200<C_T, A_T, B_T, UNIFLAG_EN, L0AB_USING_HSET>::LoadL12L0B(uint64_t kInner,
    uint16_t kC0, uint16_t kC0Tail, const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B)
{
    // fmatrix width follows the dimension laid out along the L1 row:
    // K for format(K, N), N for format(N, K); must be 16 aligned.
    if (ssBmatrixTranspose_ < 1) {
        uint16_t wAlign = CeilAlign(sBL1K_, HW_M0);
        Load3DSetFMatrixCal(sFmH_, wAlign, padList);
    } else {
        uint16_t wAlign = CeilAlign(sBL1N_, HW_M0);
        Load3DSetFMatrixCal(sFmH_, wAlign, padList);
    }
    bool isTail = kC0Tail != 0;               // nonzero kC0Tail marks the K tail tile
    uint16_t nFraC0 = CeilDiv(sMadN_, HW_N0); // N extent in fractals
    if (ssBmatrixTranspose_ > 0) {
        // SET LOAD2D parameters , loop axis: K or M, or 1
        // k is hwK0_ aligned for f32
        uint16_t l0bLoop = 1;
        uint64_t l0bSrcAddrStride = 0;
        uint64_t l0bDstAddrStride = 0;
        uint8_t l0bRepeat = kC0 * nFraC0;
        uint16_t l0bSrcstride = 1;
        uint16_t l0bDststride = 0;

        if (nFraC0 * HW_N0 == sBL1N_) {
            // N covers the full L1 row: all fractals are contiguous, a single
            // load2d issue moves the whole tile.
            l0bLoop = 1;            // loop=1
            if (isTail) {
                l0bRepeat = kC0Tail * nFraC0;
            }
        } else if (nFraC0 >= kC0) { // LOOP is K  and repeat is n axis
            l0bLoop = isTail ? kC0Tail : kC0;
            l0bSrcAddrStride = sBL1N_ * hwK0_ * typeSize_;
            l0bDstAddrStride = nFraC0 * HW_N0 * hwK0_ * typeSize_;
            l0bRepeat = nFraC0;

            l0bSrcstride = 1;
            l0bDststride = 0;
        } else { // LOOP is N  and repeat is K axis
            l0bLoop = nFraC0;
            l0bSrcAddrStride = HW_N0 * hwK0_ * typeSize_;
            l0bDstAddrStride = HW_N0 * hwK0_ * typeSize_;
            l0bRepeat = isTail ? kC0Tail : kC0;

            l0bSrcstride = (sBL1N_ + HW_N0 - 1) / HW_N0;
            l0bDststride = nFraC0 - 1;
        }

        // use load2d for L1_2_L0B
        LoadData2dParams loadDataParams;
        loadDataParams.repeatTimes = l0bRepeat;
        loadDataParams.srcStride = l0bSrcstride;
        loadDataParams.dstGap = l0bDststride;
        loadDataParams.ifTranspose = 0;
        // Start of this K tile inside the B L1 tensor, in element units.
        uint64_t l1bOffset = sBL1NOffset_ * hwK0_ + sBL1KOffset_ * sBL1N_ +
            kInner * kC0 * hwK0_ * sBL1N_;
        uint64_t l0bOffset = 0;
        for (uint64_t i = 0; i < l0bLoop; i++) {
            LoadData(l0B[l0bOffset], l1B[l1bOffset], loadDataParams);
            // Address strides are in bytes; tensor indexing is in elements.
            l1bOffset += (l0bSrcAddrStride / typeSize_);
            l0bOffset += (l0bDstAddrStride / typeSize_);
        }
    } else {
        // use load3dv2 for L1_2_L0B
        // n_axis is K direction, need to be 16 aligned
        uint16_t kAlign = isTail ? nFraC0 * HW_N0 : CeilAlign(sMadN_, hwK0_);
        uint16_t mPos = sBL1KOffset_ + kInner * sMad0K_;
        // channel size need to be 16 aligned
        uint16_t cAlign = isTail ? static_cast<uint16_t>(sBL1N_) : CeilAlign(sBL1N_ + sBL1NOffset_, ALIGN_NUM);
        // k_axis is M direction, need to be HW_M0 aligned
        uint16_t mAlign = isTail ? kC0Tail * hwK0_ : CeilAlign(sMad0K_, HW_M0);

        // k direction need to be 8 aligned for f32
        // StepN need to be aligned
        // extConfig packs the two start positions (high 32 bits) and the two
        // extents (low 32 bits) as 16-bit fields.
        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = cAlign;
        loadData3DV2.extConfig = ((uint64_t)mPos << 48) | ((uint64_t)sBL1NOffset_ << 32) |
                                ((uint64_t)mAlign << 16) | (uint64_t)kAlign;
        LoadData<B_T>(l0B[0], l1B[0], loadData3DV2);
    }
}

// One-time setup: reset all tiling/mode state, pick the fractal K depth for
// the element types, and claim the L0 buffers plus the two M->MTE1 events
// used for double buffering.
template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulV200<C_T, A_T, B_T, UNIFLAG_EN, L0AB_USING_HSET>::Init()
{
#ifdef ASCENDC_CPU_DEBUG
    // CPU debug builds have no physical L0A/L0B: back them with 64K heap
    // buffers and rebase the ping/pong offsets onto those allocations.
    pA_ = (uint64_t)((__ca__ A_T *)malloc(L0AUF_SIZE));
    pB_ = (uint64_t)((__cb__ B_T *)malloc(L0BUF_SIZE));
    l0APing_ += pA_;
    l0APong_ += pA_;
    l0BPing_ += pB_;
    l0BPong_ += pB_;
#endif
    // Ping/pong state and per-call mode switches all start cleared.
    ssAl0PingPongFlag_ = 0;
    ssBl0PingPongFlag_ = 0;
    ssAmatrixTranspose_ = 0;
    ssBmatrixTranspose_ = 0;
    biasType_ = 0;
    isGemv_ = 0;
    isScalar_ = 0;
    typeSize_ = sizeof(A_T);

    // Elements along K held by one fractal, fixed by the in/out element types.
    if constexpr (IsSameType<C_T, float>::value && IsSameType<A_T, float>::value) {
        hwK0_ = 8;   // f32 inputs
    } else if constexpr (IsSameType<C_T, float>::value && sizeof(A_T) == sizeof(half)) {
        hwK0_ = 16;  // 2-byte inputs (e.g. half) with f32 accumulation
    } else {
        hwK0_ = 32;  // remaining (e.g. 1-byte) input types
    }

    sL0cInit_ = 1;
    sL0cLast_ = 0;

    // Default fmatrix: a single row with no padding on any edge.
    sFmH_ = 1;
    sFmPadL_ = 0;
    sFmPadR_ = 0;
    sFmPadT_ = 0;
    sFmPadD_ = 0;

    // Claim the whole L0A/L0B spaces and both M->MTE1 events; pre-set the two
    // flags so the first WaitFlag in Compute() can pass.
    tpipe_->InitBuffer(l0aBuf_, L0AUF_SIZE);
    tpipe_->InitBuffer(l0bBuf_, L0BUF_SIZE);
    eventIdMToMte1Ping_ = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::M_MTE1>());
    eventIdMToMte1Pong_ = static_cast<event_t>(GetTPipePtr()->AllocEventID<HardEvent::M_MTE1>());
    SetFlag<HardEvent::M_MTE1>(eventIdMToMte1Ping_);
    SetFlag<HardEvent::M_MTE1>(eventIdMToMte1Pong_);
}

template <typename C_T, typename A_T, typename B_T, uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulV200<C_T, A_T, B_T, UNIFLAG_EN, L0AB_USING_HSET>::Compute(
    const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix, LocalTensor<C_T> &cMatrix)
{
    // Split the K dimension into kLoop full tiles of sMad0K_ plus one tail.
    uint64_t kC0 = sMad0K_ / hwK0_;          // fractals per full K tile
    uint64_t kLoop = sMadK_ / sMad0K_;       // number of full sMad0K_ iterations
    uint64_t kTail = sMadK_ - kLoop * sMad0K_;
    uint16_t sMad0KAlign = CeilAlign(sMad0K_, hwK0_);
    uint16_t kC0Norm = sMad0KAlign / hwK0_;

    LocalTensor<A_T> l0a;
    LocalTensor<B_T> l0b;
    // Ensure preceding vector-pipe work has finished before the cube starts.
    event_t eventIDVToM = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::V_M));
    SetFlag<HardEvent::V_M>(eventIDVToM);
    WaitFlag<HardEvent::V_M>(eventIDVToM);
    for (uint64_t kInner = 0; kInner < kLoop; kInner++) {
        // Low bit of the flag picks the ping (0) or pong (1) half of L0A/L0B.
        l0a = l0aBuf_.Get<A_T>();
        l0b = l0bBuf_.Get<B_T>();
        if ((ssAl0PingPongFlag_ & 0x1) != 0) {
            l0a = l0a[L0AUF_SIZE / 2 / sizeof(A_T)];
            l0b = l0b[L0BUF_SIZE / 2 / sizeof(B_T)];
        }
        // Wait until the previous MAD using this half has released it.
        event_t eventIdMToMte1PingPong = (ssAl0PingPongFlag_ & 0x1) ? eventIdMToMte1Pong_ : eventIdMToMte1Ping_;
        WaitFlag<HardEvent::M_MTE1>(eventIdMToMte1PingPong);

        // load L0A
        uint64_t aPoskPtr = kInner * kC0 * hwK0_ + sAL1KOffset_;
        LoadL12L0A(aPoskPtr, sMad0K_, l1AMatrix, l0a);
        // load L0B
        LoadL12L0B(kInner, kC0Norm, 0, l1BMatrix, l0b);
        // NOTE(review): the MTE1_M event id here is the raw ping/pong bit,
        // while the tail path below always uses EVENT_ID0 — confirm this
        // asymmetry is intended.
        SetFlag<HardEvent::MTE1_M>(ssAl0PingPongFlag_ & 0x1);
        WaitFlag<HardEvent::MTE1_M>(ssAl0PingPongFlag_ & 0x1);

        // MAD; L0C is initialized only on the first K tile (when biasType_ set)
        bool biasType = (kInner == 0) && biasType_;
        MmadMacro(l0a, l0b, cMatrix, sMad0K_, biasType);
        SetFlag<HardEvent::M_MTE1>(eventIdMToMte1PingPong);

        // update pingpong flag
        ssAl0PingPongFlag_ += useL0PingPong_;
        ssBl0PingPongFlag_ += useL0PingPong_;
    }
    // k  tail
    if (kTail != 0) {
        uint16_t madKC0 = CeilDiv(sMadK_, hwK0_);
        uint64_t kC0Tail = madKC0 - kLoop * kC0; // loop count of the tail block, in units of 16

        l0a = l0aBuf_.Get<A_T>();
        l0b = l0bBuf_.Get<B_T>();
        if ((ssAl0PingPongFlag_ & 0x1) != 0) {
            l0a = l0a[L0AUF_SIZE / 2 / sizeof(A_T)];
            l0b = l0b[L0BUF_SIZE / 2 / sizeof(B_T)];
        }
        event_t eventIdPingPong = (ssAl0PingPongFlag_ & 0x1) ? eventIdMToMte1Pong_ : eventIdMToMte1Ping_;
        WaitFlag<HardEvent::M_MTE1>(eventIdPingPong);
        uint16_t tailK = kC0Tail * hwK0_;       // tail K rounded up to whole fractals
        uint64_t aPoskPtr = kLoop * kC0 * hwK0_ + sAL1KOffset_;
        // load L0A
        LoadL12L0A(aPoskPtr, tailK, l1AMatrix, l0a);
        // load L0B
        LoadL12L0B(kLoop, kC0, kC0Tail, l1BMatrix, l0b);

        SetFlag<HardEvent::MTE1_M>(EVENT_ID0);
        WaitFlag<HardEvent::MTE1_M>(EVENT_ID0);
        // MAD with the residual K (kTail may be shorter than a full tile)
        bool biasType = (kLoop == 0) && biasType_;
        MmadMacro(l0a, l0b, cMatrix, kTail, biasType);

        SetFlag<HardEvent::M_MTE1>(eventIdPingPong);
        ssAl0PingPongFlag_ += useL0PingPong_;
        ssBl0PingPongFlag_ += useL0PingPong_;
    }
}
} // namespace matmul
#endif
